{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Code for a painless q-learning tutorial\n",
    "\n",
    "- Refer to <http://mnemstudio.org/path-finding-q-learning-tutorial.htm>\n",
    "- Build a simple AI based on `numpy`.\n",
    "- Written purely as a study exercise.\n",
    "- Author: [JasonQSY](https://github.com/JasonQSY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# we just need numpy\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial tables and hyperparameters.\n",
    "# Plain 2-D ndarrays are used instead of np.matrix, which NumPy no longer\n",
    "# recommends. Every indexing expression applied to q and r in this notebook\n",
    "# (q[s, a], q[s].max(), np.argmax(q[s])) gives the same result on an ndarray.\n",
    "\n",
    "# q is the tabular representation of the Q function: q[state, action].\n",
    "# It starts at zero and is filled in by the training loop.\n",
    "q = np.zeros((6, 6))\n",
    "\n",
    "# r is the tabular representation of the rewards: r[state, action].\n",
    "# r is predefined and remains unchanged.\n",
    "#    -1 -> the move is not allowed (training only considers r >= 0)\n",
    "#     0 -> allowed move with no immediate reward\n",
    "#   100 -> allowed move into state 5, the terminal state\n",
    "r = np.array([[-1, -1, -1, -1,  0,  -1],\n",
    "              [-1, -1, -1,  0, -1, 100],\n",
    "              [-1, -1, -1,  0, -1,  -1],\n",
    "              [-1,  0,  0, -1,  0,  -1],\n",
    "              [ 0, -1, -1,  0, -1, 100],\n",
    "              [-1,  0, -1, -1,  0, 100]])\n",
    "\n",
    "# Hyperparameters\n",
    "gamma = 0.8    # discount factor applied to the best Q value of the next state\n",
    "epsilon = 0.4  # probability of taking a random feasible action"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "There is one difference between the code here and the guide: we use an epsilon-greedy policy, i.e. with probability epsilon we choose a random feasible action; otherwise we choose the greedy action, the one that maximizes the Q value of the current state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------------------------\n",
      "Training episode: 0\n",
      "[[   0.    0.    0.    0.    0.    0.]\n",
      " [   0.    0.    0.    0.    0.    0.]\n",
      " [   0.    0.    0.    0.    0.    0.]\n",
      " [   0.    0.    0.    0.    0.    0.]\n",
      " [   0.    0.    0.    0.    0.  100.]\n",
      " [   0.    0.    0.    0.    0.    0.]]\n",
      "------------------------------------------------\n",
      "Training episode: 10\n",
      "[[   0.    0.    0.    0.   80.    0.]\n",
      " [   0.    0.    0.   64.    0.  100.]\n",
      " [   0.    0.    0.   64.    0.    0.]\n",
      " [   0.   80.    0.    0.    0.    0.]\n",
      " [   0.    0.    0.   64.    0.  100.]\n",
      " [   0.    0.    0.    0.    0.    0.]]\n",
      "------------------------------------------------\n",
      "Training episode: 20\n",
      "[[   0.     0.     0.     0.    80.     0. ]\n",
      " [   0.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.    64.     0.     0. ]\n",
      " [   0.    80.    51.2    0.    80.     0. ]\n",
      " [  64.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.     0.     0.     0. ]]\n",
      "------------------------------------------------\n",
      "Training episode: 30\n",
      "[[   0.     0.     0.     0.    80.     0. ]\n",
      " [   0.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.    64.     0.     0. ]\n",
      " [   0.    80.    51.2    0.    80.     0. ]\n",
      " [  64.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.     0.     0.     0. ]]\n",
      "------------------------------------------------\n",
      "Training episode: 40\n",
      "[[   0.     0.     0.     0.    80.     0. ]\n",
      " [   0.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.    64.     0.     0. ]\n",
      " [   0.    80.    51.2    0.    80.     0. ]\n",
      " [  64.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.     0.     0.     0. ]]\n",
      "------------------------------------------------\n",
      "Training episode: 50\n",
      "[[   0.     0.     0.     0.    80.     0. ]\n",
      " [   0.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.    64.     0.     0. ]\n",
      " [   0.    80.    51.2    0.    80.     0. ]\n",
      " [  64.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.     0.     0.     0. ]]\n",
      "------------------------------------------------\n",
      "Training episode: 60\n",
      "[[   0.     0.     0.     0.    80.     0. ]\n",
      " [   0.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.    64.     0.     0. ]\n",
      " [   0.    80.    51.2    0.    80.     0. ]\n",
      " [  64.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.     0.     0.     0. ]]\n",
      "------------------------------------------------\n",
      "Training episode: 70\n",
      "[[   0.     0.     0.     0.    80.     0. ]\n",
      " [   0.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.    64.     0.     0. ]\n",
      " [   0.    80.    51.2    0.    80.     0. ]\n",
      " [  64.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.     0.     0.     0. ]]\n",
      "------------------------------------------------\n",
      "Training episode: 80\n",
      "[[   0.     0.     0.     0.    80.     0. ]\n",
      " [   0.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.    64.     0.     0. ]\n",
      " [   0.    80.    51.2    0.    80.     0. ]\n",
      " [  64.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.     0.     0.     0. ]]\n",
      "------------------------------------------------\n",
      "Training episode: 90\n",
      "[[   0.     0.     0.     0.    80.     0. ]\n",
      " [   0.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.    64.     0.     0. ]\n",
      " [   0.    80.    51.2    0.    80.     0. ]\n",
      " [  64.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.     0.     0.     0. ]]\n",
      "------------------------------------------------\n",
      "Training episode: 100\n",
      "[[   0.     0.     0.     0.    80.     0. ]\n",
      " [   0.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.    64.     0.     0. ]\n",
      " [   0.    80.    51.2    0.    80.     0. ]\n",
      " [  64.     0.     0.    64.     0.   100. ]\n",
      " [   0.     0.     0.     0.     0.     0. ]]\n"
     ]
    }
   ],
   "source": [
    "# Main training loop: 101 episodes of tabular Q-learning.\n",
    "for episode in range(101):\n",
    "    # Every episode starts from a uniformly random state.\n",
    "    state = np.random.randint(0, 6)\n",
    "\n",
    "    # The episode ends as soon as the terminal state (5) is reached.\n",
    "    while state != 5:\n",
    "        # Keep only the feasible moves. r[state, action] == -1 marks a move that\n",
    "        # is physically impossible, so it is excluded even from random exploration.\n",
    "        feasible_actions = [a for a in range(6) if r[state, a] >= 0]\n",
    "        feasible_q = [q[state, a] for a in feasible_actions]\n",
    "\n",
    "        # Epsilon-greedy selection: explore with probability epsilon,\n",
    "        # otherwise take the feasible action with the highest Q value.\n",
    "        if np.random.random() < epsilon:\n",
    "            pick = np.random.randint(0, len(feasible_actions))\n",
    "        else:\n",
    "            pick = np.argmax(feasible_q)\n",
    "        action = feasible_actions[pick]\n",
    "\n",
    "        # Q update: immediate reward plus the discounted best Q value of the\n",
    "        # state we land in (taking action a moves the agent to state a).\n",
    "        q[state, action] = r[state, action] + gamma * q[action].max()\n",
    "\n",
    "        # Move on to the next state.\n",
    "        state = action\n",
    "\n",
    "    # Report training progress every 10 episodes.\n",
    "    if episode % 10 == 0:\n",
    "        print(\"------------------------------------------------\")\n",
    "        print(\"Training episode: %d\" % episode)\n",
    "        print(q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the Q matrix converges, and it is the same as the matrix shown in the guide. The only difference is that every entry in row 6 (state 5) is 0. That is because we never update the Q matrix when state == 5: the training loop exits as soon as the terminal state is reached. I have no idea why row 6 in the guide has positive values. Anyway, let's verify the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "episode: 0\n",
      "the robot borns in 4.\n",
      "the robot goes to 5.\n",
      "episode: 1\n",
      "the robot borns in 2.\n",
      "the robot goes to 3.\n",
      "the robot goes to 1.\n",
      "the robot goes to 5.\n",
      "episode: 2\n",
      "the robot borns in 5.\n",
      "episode: 3\n",
      "the robot borns in 0.\n",
      "the robot goes to 4.\n",
      "the robot goes to 5.\n",
      "episode: 4\n",
      "the robot borns in 0.\n",
      "the robot goes to 4.\n",
      "the robot goes to 5.\n",
      "episode: 5\n",
      "the robot borns in 0.\n",
      "the robot goes to 4.\n",
      "the robot goes to 5.\n",
      "episode: 6\n",
      "the robot borns in 0.\n",
      "the robot goes to 4.\n",
      "the robot goes to 5.\n",
      "episode: 7\n",
      "the robot borns in 1.\n",
      "the robot goes to 5.\n",
      "episode: 8\n",
      "the robot borns in 2.\n",
      "the robot goes to 3.\n",
      "the robot goes to 1.\n",
      "the robot goes to 5.\n",
      "episode: 9\n",
      "the robot borns in 4.\n",
      "the robot goes to 5.\n"
     ]
    }
   ],
   "source": [
    "# Verify: roll out the greedy policy from 10 random start states.\n",
    "for i in range(10):\n",
    "    # one episode\n",
    "    print(\"episode: %d\" % i)\n",
    "\n",
    "    # random initial state\n",
    "    state = np.random.randint(0, 6)\n",
    "    print(\"the robot borns in \" + str(state) + \".\")\n",
    "    for _ in range(20):  # step cap to prevent an endless loop\n",
    "        if state == 5:\n",
    "            break\n",
    "        # Greedy step restricted to feasible moves (r >= 0). An argmax over the\n",
    "        # whole row would also consider infeasible moves and, for a row that is\n",
    "        # still all zeros, would silently pick state 0 even if it is not allowed.\n",
    "        # max() returns the first best candidate, matching np.argmax tie-breaking.\n",
    "        feasible_actions = [a for a in range(6) if r[state, a] >= 0]\n",
    "        action = max(feasible_actions, key=lambda a: q[state, a])\n",
    "        print(\"the robot goes to %d.\" % action)\n",
    "        state = action\n",
    "\n",
    "    # Make a failed rollout visible instead of ending the episode silently.\n",
    "    if state != 5:\n",
    "        print(\"the robot did not reach state 5 within 20 steps.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Now the robot shows its ability to find the correct path quickly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
